import jax
import jax.numpy as jnp

def global_loss(dij, d_zij, dij_max):
    """Return the mean squared difference of ``dij`` and ``d_zij``, scaled by ``dij_max``.

    Args:
        dij: Reference values (presumably pairwise distances in data
            space -- TODO confirm against caller).
        d_zij: Values compared against ``dij``; must broadcast with it.
        dij_max: Divisor applied to the difference before squaring.

    Returns:
        The mean of ``((dij - d_zij) / dij_max) ** 2`` over all entries.
    """
    scaled_residual = (dij - d_zij) / dij_max
    return jnp.mean(scaled_residual ** 2)

def graph_matching_loss(dij, d_zij, dij_diff_max):
    """Return the mean squared mismatch between pairwise differences of two arrays.

    For each input, the differences between every pair of entries along the
    leading axis are formed by broadcasting. The two difference arrays are
    subtracted, divided by ``dij_diff_max``, squared and averaged.

    Args:
        dij: Reference values (presumably data-space distances -- TODO confirm).
        d_zij: Values compared against ``dij``; indexed the same way.
        dij_diff_max: Divisor applied to the mismatch before squaring.

    Returns:
        The mean of the squared, scaled mismatch over all entries.
    """
    def _pairwise_differences(values):
        # New axis at position 1 minus new axis at position 0 broadcasts to
        # values[i] - values[j] for every pair (i, j).
        return values[:, None] - values[None]

    mismatch = _pairwise_differences(dij) - _pairwise_differences(d_zij)
    return jnp.mean((mismatch / dij_diff_max) ** 2)

def low_dimensional_loss(z, n_dim):
    """Return the mean absolute value of the leading ``z.shape[1] - n_dim`` columns of ``z``.

    The absolute values are first averaged along axis 1 and the resulting
    per-row values are then averaged.

    Args:
        z: Array with at least two axes; columns are taken along axis 1.
        n_dim: Number of trailing columns excluded from the penalty. Must be
            a concrete Python integer because it determines a slice bound.

    Returns:
        The scalar loss. Zero is returned when ``n_dim >= z.shape[1]``, i.e.
        when no columns are left to penalise.
    """
    n_penalised = z.shape[1] - n_dim
    if n_penalised <= 0:
        # Without this guard an empty slice would produce NaN
        # (n_dim == z.shape[1]) and a negative stop index would silently
        # select the wrong columns (n_dim > z.shape[1]).
        return jnp.zeros((), dtype=jnp.result_type(z, 0.0))
    per_row = jnp.mean(jnp.abs(z[:, :n_penalised]), axis=1)
    return jnp.mean(per_row)

def kinetic_energy_loss(vectorfields):
    """Return the mean of the squared entries of ``vectorfields`` summed along axis 2.

    Args:
        vectorfields: Array with at least three axes (presumably axis 2
            holds vector components -- TODO confirm against caller).

    Returns:
        The scalar mean, over the remaining axes, of the axis-2 sums of squares.
    """
    squared_magnitudes = jnp.sum(vectorfields ** 2, axis=2)
    return jnp.mean(squared_magnitudes)

def jacobian_loss(e_vjps):
    """Return the mean of the squared entries of ``e_vjps``.

    Args:
        e_vjps: Array of values to penalise (presumably vector-Jacobian
            products -- TODO confirm against caller).

    Returns:
        The scalar mean of ``e_vjps ** 2`` over all entries.
    """
    squared = e_vjps ** 2
    return jnp.mean(squared)

def low_rank_loss(x, z, varphi, pars, args):
    """Return the reconstruction error of ``x`` from a low-rank approximation of ``z``.

    ``z`` is centred by its mean over axis 0 and projected onto the
    ``args.n_dim`` eigenvectors of its Gram matrix with the largest
    eigenvalues. The mean is added back, the result is passed through
    ``varphi.apply(pars, ..., method="inverse")``, and the output is compared
    with ``x``.

    Args:
        x: Target array; compared with the inverse-mapped approximation.
        z: Two-dimensional array of shape ``(N, d)``.
        varphi: Object whose ``apply(pars, array, method="inverse")`` returns
            a pair; only the first element is used (presumably a Flax
            module -- TODO confirm).
        pars: Passed through unchanged as the first argument of ``varphi.apply``.
        args: Object exposing an integer attribute ``n_dim``.

    Returns:
        The scalar mean over samples of the squared error summed along the last axis.
    """
    center = jnp.mean(z, axis=0, keepdims=True)
    # Broadcasting subtracts the (1, d) centre from every row, so no explicit
    # tiling is needed.
    centered = z - center
    gram_matrix = jnp.einsum("Ni,Mi->NM", centered, centered)
    # eigh returns eigenvalues in ascending order, so the trailing columns
    # are the eigenvectors with the largest eigenvalues. The eigenvalues
    # themselves are not needed.
    # NOTE(review): n_dim == 0 would make the ``-0:`` slice select every
    # column rather than none -- confirm callers always pass n_dim >= 1.
    _, eigenvectors = jnp.linalg.eigh(gram_matrix)
    top_vectors = eigenvectors[:, -args.n_dim:]
    coefficients = jnp.einsum("NM,Ni->Mi", top_vectors, centered)
    projected = jnp.einsum("Mi,NM->Ni", coefficients, top_vectors)
    lr_approx = center + projected
    z_inv_lr_approx, _ = varphi.apply(pars, lr_approx, method="inverse")
    return jnp.mean(jnp.sum((x - z_inv_lr_approx) ** 2, axis=-1))

def inverse_loss(x, z_inv):
    """Return the mean over samples of the squared error between ``x`` and ``z_inv``.

    Args:
        x: Target array.
        z_inv: Array compared with ``x``; must broadcast with it.

    Returns:
        The scalar mean of the squared differences summed along the last axis.
    """
    residual = x - z_inv
    per_sample_error = jnp.sum(residual ** 2, axis=-1)
    return jnp.mean(per_sample_error)